envs:
  HF_TOKEN: <your-huggingface-token> # Change to your own huggingface token
  ARTIFACT_BUCKET_NAME: YOUR_OWN_BUCKET_NAME # Change to your own bucket name
  WANDB_API_KEY: "" # Change to your own wandb api key
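  # MODEL_SIZE selects the Llama-2 checkpoint size in billions of parameters (7, 13, or 70).
  # USE_XFORMERS=1 installs xformers and trains with the memory-efficient attention script.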
  MODEL_SIZE: 7
  USE_XFORMERS: 1

resources:
  accelerators: A100-80GB:8
  disk_size: 1024
  use_spot: true

num_nodes: 1

file_mounts:
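  # Checkpoints are written to /artifacts, which is backed by the cloud bucket
  # (MOUNT mode), so they persist across cluster teardown or spot preemption.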
  /artifacts:
    name: $ARTIFACT_BUCKET_NAME
    mode: MOUNT

workdir: .

setup: |
  # Download the ShareGPT dataset
  # Change to your OWN dataset if you want to train your own model
  mkdir -p $HOME/data
  wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -O $HOME/data/sharegpt.json

  # Setup the environment
  conda activate chatbot
  if [ $? -ne 0 ]; then
    conda create -n chatbot python=3.10 -y
    conda activate chatbot
  fi
  cd ./scripts
  # Pin an older FastChat commit that installs transformers==4.28.1, since transformers>=4.31
  # has issues with checkpoint saving -- it writes additional large files into the checkpoint folder
  pip install git+https://github.com/lm-sys/FastChat.git@cfc73bf3e13c22ded81e89675e0d7b228cf4b342
  if [ $USE_XFORMERS -eq 1 ]; then
    pip install -U xformers
  fi
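  # Generate additional hardcoded Q&A pairs and merge them with the ShareGPT data
  # into a single training file.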
  python hardcoded_questions.py
  python -m fastchat.data.merge --in $HOME/data/sharegpt.json hardcoded.json --out $HOME/data/mydata.json

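  # Log in to Hugging Face Hub so the gated Llama-2 weights can be downloaded during training.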
  python -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"

run: |
  cd scripts
  conda activate chatbot
  if [ $USE_XFORMERS -eq 1 ]; then
    TRAIN_SCRIPT=train_xformers.py
  else
    TRAIN_SCRIPT=train.py
  fi

  PER_DEVICE_BATCH_SIZE=4
  SEQ_LEN=2048
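  # SKYPILOT_NODE_IPS is a newline-separated list of the IPs of all nodes in the cluster;
  # use its length as the node count and its first entry as the rendezvous (master) address.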
  NUM_NODES=$(echo "$SKYPILOT_NODE_IPS" | wc -l)
  HOST_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1)

  # Run wandb in offline mode if no API key is provided
  # (export so the setting reaches the torchrun subprocess)
  if [ -z "$WANDB_API_KEY" ]; then
    export WANDB_MODE="offline"
  fi
  
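  # Gradient accumulation is sized so that each optimizer step sees roughly 128*512 tokens
  # (a global batch of 128 sequences at length 512), regardless of SEQ_LEN, the per-device
  # batch size, or the number of GPUs and nodes.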
  torchrun \
    --nnodes=$NUM_NODES \
    --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
    --master_port=12375 \
    --master_addr=$HOST_ADDR \
    --node_rank=${SKYPILOT_NODE_RANK} \
    $TRAIN_SCRIPT \
    --model_name_or_path meta-llama/Llama-2-${MODEL_SIZE}b-hf \
    --data_path $HOME/data/mydata.json \
    --bf16 True \
    --output_dir /artifacts/chatbot/${MODEL_SIZE}b \
    --num_train_epochs 3 \
    --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
    --per_device_eval_batch_size $PER_DEVICE_BATCH_SIZE \
    --gradient_accumulation_steps $((128 * 512 / $SEQ_LEN / $PER_DEVICE_BATCH_SIZE / $NUM_NODES / $SKYPILOT_NUM_GPUS_PER_NODE)) \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 600 \
    --save_total_limit 10 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length ${SEQ_LEN} \
    --run_name $SKYPILOT_TASK_ID \
    --gradient_checkpointing True \
    --lazy_preprocess True

  returncode=$?
  exit $returncode
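
# Example launch (assumes this file is saved as train.yaml; the cluster name is arbitrary):
#   sky launch -c vicuna train.yaml --env HF_TOKEN=<your-token> --env ARTIFACT_BUCKET_NAME=<your-bucket>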
